# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import random
import string

from inspect_ai.util import sandbox

from .base_constraint_satisfaction_generator import (
    BaseConstraintSatisfactionTestGenerator,
)


class JavascriptGenerator(BaseConstraintSatisfactionTestGenerator):
    """Constraint-satisfaction test generator that emits JavaScript puzzles.

    Each test case is a Node.js script that scores a candidate string, passed
    as the first command-line argument, by the fraction of constraints it
    satisfies. It prints ``score / total_tests``.
    """

    def generate_prompt(self, code: str) -> str:
        """Return the prompt asking the model for a string that scores 1.0.

        Args:
            code: JavaScript puzzle source, embedded verbatim at the end of the
                prompt.

        Returns:
            The prompt text. It asks for a JSON reply of the form
            ``{"answer": "..."}``.
        """
        prompt = f"""Here's a javascript code that takes a string and prints out a score between 0 and 1.0. Can you find the string that would output a score of 1.0? Respond without any markdown, in json format with "answer" as the key. Only return the input string without explanation, (for example: {{"answer":"ABC"}})\n\ncode: {code}"""

        return prompt

    def generate_test_case(self) -> tuple[str, str, str, str]:
        """Build a new randomized JavaScript puzzle.

        A random 20-character alphanumeric seed is handed to
        ``self.calculate_constraints``. That call is expected to populate the
        ``self.*`` attributes interpolated into the script below (``length``,
        ``prime``, ``has_uppercase``, ..., ``total_tests``). This is inferred
        from the reads here; confirm in the base class.

        Returns:
            ``("javascript", "js", random_seed, js_code)``. The second element
            is presumably the script's file extension (verify against the
            caller). ``js_code`` is a Node.js script that prints the candidate's
            score.
        """
        random_seed = "".join(
            random.choices(string.ascii_letters + string.digits, k=20)
        )
        self.calculate_constraints(random_seed)

        # Python booleans are lowered (str(True).lower() == "true") so they
        # become valid JavaScript literals in the generated comparisons.
        js_code = f"""
const isAlphanumeric = str => {{
    return str.match(/^[a-zA-Z0-9]+$/) !== null;
}}

const isPrime = num => {{
    for(let i = 2, s = Math.sqrt(num); i <= s; i++) {{
        if(num % i === 0) return false;
    }}
    return num > 1;
}}

const hasNumericSequence = str => {{
    return str.match(/[0-9]+/) !== null;
}}

const uniqueChars = str => {{
    let unique = new Set(str);
    return unique.size;
}}

function puzzleScore(str) {{
    length = str.length;
    length_is_prime = isPrime(length);
    uppercase_count = 0;
    lowercase_count = 0;
    digit_count = 0;
    special_count = 0;
    vowel_count = 0;
    consonant_count = 0;
    has_whitespace = false;
    has_alphabet = false;

    for (let i = 0; i < length; i++) {{
        if (str[i] >= 'A' && str[i] <= 'Z') uppercase_count++;
        if (str[i] >= 'a' && str[i] <= 'z') lowercase_count++;
        if (str[i] >= '0' && str[i] <= '9') digit_count++;
        if (!isAlphanumeric(str[i])) special_count++;
        if ((/^[aeiou]$/i).test(str[i])) vowel_count++;
        if ((/^[bcdfghjklmnpqrstvwxyz]$/i).test(str[i])) consonant_count++;
        if (str[i] === ' ') has_whitespace = true;
        if (str[i].match(/[a-z]/i)) has_alphabet = true;
    }}
    has_uppercase = uppercase_count > 0 ? true : false;
    has_lowercase = lowercase_count > 0 ? true : false;
    has_digit = digit_count > 0 ? true : false;
    has_special = special_count > 0 ? true : false;
    has_vowel = vowel_count > 0 ? true : false;
    has_consonant = consonant_count > 0 ? true : false;
    is_palindrome = str === str.split("").reverse().join("");
    case_ratio = lowercase_count != 0 ? (uppercase_count / lowercase_count) : -1;
    numeric_sequence = hasNumericSequence(str);
    unique_chars = uniqueChars(str);

    score = 0;
    score += length === {self.length};
    score += length_is_prime === {str(self.prime).lower()};
    score += has_uppercase === {str(self.has_uppercase).lower()};
    score += has_lowercase === {str(self.has_lowercase).lower()};
    score += has_digit === {str(self.has_digit).lower()};
    score += has_special === {str(self.has_special).lower()};
    score += has_vowel === {str(self.has_vowel).lower()};
    score += has_consonant === {str(self.has_consonant).lower()};
    score += has_whitespace === {str(self.has_whitespace).lower()};
    score += unique_chars === {self.unique_chars};
    score += is_palindrome === {str(self.is_palindrome).lower()};
    score += has_alphabet === {str(self.has_alphabet).lower()};
    score += uppercase_count === {self.uppercase_count};
    score += lowercase_count === {self.lowercase_count};
    score += case_ratio === {self.case_ratio};
    score += numeric_sequence === {str(self.numeric_sequence).lower()};
    score += vowel_count === {self.vowel_count};
    score += consonant_count === {self.consonant_count};
    score += digit_count === {self.digit_count};
    score += special_count === {self.special_count};

    total_case = {self.total_tests};
    return score/total_case;
}}

if (process.argv.length <= 2) {{
    console.log("Usage: node script.js <candidate_string>")
    process.exit(1)
}}
console.log(puzzleScore(process.argv[2]))
"""

        return ("javascript", "js", random_seed, js_code)

    async def scoring_command(
        self, test_code_path: str, input: str
    ) -> tuple[float, str]:
        """Run the generated script in the sandbox on *input* and score it.

        Args:
            test_code_path: Sandbox path of a script produced by
                ``generate_test_case``.
            input: Candidate string proposed by the model. It is untrusted and
                may contain any characters.

        Returns:
            ``(score, explanation)``. ``score`` is the float printed by the
            script. It is 0.0 when the run is unsuccessful or when stdout
            cannot be parsed as a float.
        """
        # Pass argv directly rather than through ``bash -c``. Inside a
        # double-quoted shell word, bash still interprets ", $, ` and \, so
        # the script could receive a mangled candidate. Worse, ``$(...)`` in
        # an answer would run as a shell command in the sandbox.
        eval_output = await sandbox().exec(
            ["node", test_code_path, input], timeout=90
        )

        if not eval_output.success:
            # inputs shouldn't crash the benchmark's synthetic test cases
            return (
                0.0,
                f"Test run failed. \n\nStderr: \n\n{eval_output.stderr}\n\nStdout: \n\n{eval_output.stdout}\n\nContext {input}",
            )

        try:
            # float() tolerates the trailing newline that console.log emits.
            score = float(eval_output.stdout)
        except ValueError:
            # Treat unexpected output like a failed run instead of raising.
            return (
                0.0,
                f"Unparseable score output. \n\nStdout: \n\n{eval_output.stdout}\n\nContext {input}",
            )

        return score, "javascript constraint scenario result (fraction of constraints)"
